scaled_dot_product_attention api #55242
Conversation
Your PR was submitted successfully. Thank you for your contribution to the open-source project!
✅ This PR's description meets the template requirements! |
Sorry to inform you that d419bc7's CIs have passed for more than 7 days. To prevent PR conflicts, you need to re-run all CIs manually.
where : ``Q``, ``K``, and ``V`` represent the three input parameters of the attention module.
The dimensions of the three parameters are the same.
``d`` represents the size of the last dimension of the three parameters.
Here, Q, K, and V denote the three input parameters of the attention module, all sharing identical dimensions. d represents the size of the last dimension of these three parameters.
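For context, the description under review corresponds to the standard scaled dot-product attention formula (as introduced in the original Transformer paper):

```latex
\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\!\left(\frac{Q K^{\top}}{\sqrt{d}}\right) V
```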
In mathematical formulas, "where" is generally used.
OK. This change was generated with ChatGPT, for reference only.
👌
The dtype can be float16 or bfloat16.

Examples:
    .. code-block:: python
The framework is introducing xdoctest; while you're at it, the example code could be converted to the xdoctest-supported format. See #55295.
What does the xdoctest-supported format look like?
Is there a demo or a clear specification?
Please refer to the changes in the PR I linked.
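For readers unfamiliar with the format: xdoctest-style examples use standard `>>>` prompts with `# xdoctest:` directive comments inside the docstring's code-block (the `# xdoctest: -SKIP` line visible in this diff follows the same convention). The sketch below uses a stand-in function (`add` is hypothetical, not part of this PR), and the directive shown is only illustrative:

```python
import doctest

def add(x, y):
    """Add two numbers.

    Examples:
        .. code-block:: python

            >>> # xdoctest: +REQUIRES(env:FLAGS_run_examples)
            >>> add(1, 2)
            3
    """
    return x + y

# The same >>>-style examples also run under the stdlib doctest runner,
# which treats the xdoctest directive line as an ordinary comment.
results = doctest.testmod(verbose=False)
print(results.failed)  # → 0
```

Under xdoctest, the `+REQUIRES(...)` directive would skip the example unless the named environment flag is set; under plain doctest the directive line is inert.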
LGTM
@@ -407,4 +407,57 @@ def flash_attn_unpadded(
    return out, softmax if return_softmax else None

scaled_dot_product_attention = flash_attention
def scaled_dot_product_attention(
    query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False
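For intuition, the semantics this signature implies (scaled dot products, optional additive mask, optional causal masking, softmax over keys) can be sketched as a NumPy reference implementation. This is an illustration only, not the PR's flash-attention kernel; dropout is omitted, and a [batch, num_heads, seq_len, head_dim] layout is assumed for simplicity (the real API's layout may differ):

```python
import numpy as np

def sdpa_reference(query, key, value, attn_mask=None, is_causal=False):
    """NumPy sketch of scaled dot-product attention.

    Shapes assumed: [batch, num_heads, seq_len, head_dim].
    """
    d = query.shape[-1]
    # QK^T / sqrt(d): [batch, heads, seq_q, seq_k]
    scores = query @ key.transpose(0, 1, 3, 2) / np.sqrt(d)
    if is_causal:
        # Mask out future positions (strict upper triangle).
        seq_q, seq_k = scores.shape[-2:]
        future = np.triu(np.ones((seq_q, seq_k), dtype=bool), k=1)
        scores = np.where(future, -np.inf, scores)
    if attn_mask is not None:
        # Additive mask, broadcast over batch/head dimensions.
        scores = scores + attn_mask
    # Numerically stable softmax over the key axis.
    scores = scores - scores.max(axis=-1, keepdims=True)
    weights = np.exp(scores)
    weights = weights / weights.sum(axis=-1, keepdims=True)
    return weights @ value
```

With `is_causal=True`, position 0 can attend only to key 0, so its output row equals `value[..., 0, :]` exactly, which gives a quick sanity check.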
To be consistent with other APIs, there should be a final ``name=None`` parameter.
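Applied to the signature in this diff, the suggestion would look like the following sketch (only the trailing ``name=None`` is new; the body is elided here):

```python
def scaled_dot_product_attention(
    query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, name=None
):
    # `name` follows the common Paddle convention of a trailing optional
    # name parameter; most callers leave it as None.
    ...
```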
>>> print(output)
>>> # xdoctest: -SKIP
"""
assert attn_mask is None, "attn_mask is not supported yet"
If attn_mask is not currently supported, add a TODO statement to indicate that it will be supported later.
Work to support attn_mask is already underway, and it depends on this PR being merged first.
OK
PR types: Others
PR changes: APIs
Description: scaled_dot_product_attention api
card-72806